import numpy as np
import TLSM_method as Tm
from Hamming_error import hamming_error
import tensorly as tl
import rpy2.robjects as robjects
import rpy2.robjects.numpy2ri
from spectral_aggregation import mean_adj
from hosvd_tucker import hosvd_tucker


# Let rpy2 convert numpy arrays automatically, then load the R source files
# (the R functions called later, e.g. GetCluster and speck, are presumably
# defined in them).
rpy2.robjects.numpy2ri.activate()
for r_script in ('LSM.r', 'Codes_Spectral_Matrix.r'):
    robjects.r.source(r_script)

Label = np.load('yeast_label.npy')
Cov = np.load('yeast_expression.npy')
# Shift ground-truth labels down by one (presumably 1-based on disk).
L = Label - 1
# Cov is expected to be 205 x 80 with the 4 experiments interleaved column-wise
# (columns layer, layer + 4, layer + 8, ... belong to experiment `layer`).
Cov_all = [Cov[:, np.arange(layer, 80, 4)] for layer in range(4)]


def Sim(X_all, sigma=1.):
    """Return one pairwise-similarity matrix per data matrix in ``X_all``.

    For each ``X`` (its first axis indexes observations), entry ``[i, j]`` of
    the similarity matrix is ``exp(-||X[i] - X[j]||)``, the exponential of the
    negative Euclidean distance between observations ``i`` and ``j``. Every
    matrix is therefore symmetric with ones on the diagonal.

    Args:
        X_all: iterable of array-likes. Axes after the first are flattened, so
            the distance is the Euclidean (Frobenius) norm of the difference of
            two observations.
        sigma: ignored by the current kernel; kept so existing calls that pass
            it keep working.

    Returns:
        ``np.array`` built from the list of ``(n, n)`` similarity matrices,
        i.e. shape ``(len(X_all), n, n)`` when all inputs share the same ``n``.

    Note:
        Distances are computed by broadcasting, which allocates an
        ``(n, n, d)`` temporary per input (``d`` = flattened feature size).
    """
    Sim_ALL = []
    for X in X_all:
        X = np.asarray(X)
        n_obs = X.shape[0]
        # Flatten each observation to a row vector; prod(()) == 1 covers 1-D input.
        rows = X.reshape(n_obs, int(np.prod(X.shape[1:])))
        # All pairwise differences at once instead of an O(n^2) Python loop.
        # ||a - b|| == ||b - a|| exactly, so the result is exactly symmetric.
        diff = rows[:, None, :] - rows[None, :, :]
        Sim_ALL.append(np.exp(-np.linalg.norm(diff, axis=-1)))
    return np.array(Sim_ALL)


# Similarity tensor: Sim returns one (n, n) matrix per experiment, stacked on axis 0.
Net_sim = Sim(Cov_all)

# Binarize the similarities into an adjacency tensor. A single cutoff (the
# `quan` quantile) is computed jointly over all entries of all layers, so every
# layer is thresholded at the same similarity level.
quan = 0.6  # best quantile cut
Cutoff = np.quantile(Net_sim, quan)
# Move the layer axis last -> shape (n, n, layers), the layout passed to the
# methods below. Entries above the cutoff become 1.0, all others 0.0. The
# dimensions come from Net_sim rather than being hard-coded.
Network = np.ascontiguousarray(np.moveaxis(Net_sim > Cutoff, 0, -1), dtype=float)


# Fit the TLSM model on the binary adjacency tensor and score its labels against
# the ground truth. NOTE(review): the second argument (4) is presumably the number
# of communities -- confirm against TLSM_method.TLSM.
Model_Yao = Tm.TLSM(Network, 4)
# Hyper-parameters are assigned as attributes, in this order, before training.
for _attr, _value in (
    ('lambda_n', 1e-6),
    ('check', 10),
    ('eta', 100),
    ('ite_num', 5000),
):
    setattr(Model_Yao, _attr, _value)
Model_label = Model_Yao.training()
tlsm_error = hamming_error([L, Model_label])
print("Hamming error of TLSM:", tlsm_error)

# Dimensions: M layers, n nodes, K communities. M and n are read from the
# (n, n, M) adjacency tensor instead of being hard-coded.
M = Network.shape[2]
n = Network.shape[0]
K = 4
# Network unfolded along mode 2 and refolded to shape (M, n, n).
A_LSE = tl.unfold(Network, 2)
A_LSE = tl.fold(A_LSE, 0, (M, n, n))
# NOTE(review): A_LSE is never used; GetCluster receives the raw (non-binarized)
# similarity tensor Net_sim instead, unlike the other methods, which all use the
# binary Network. Confirm which input the "LSE" baseline is meant to use.
psi_hat_lei = robjects.r.GetCluster(Net_sim, K)
# Shift the R labels down by one (presumably 1-based from R -- confirm).
psi_hat_lei = np.array(psi_hat_lei) - 1
print("Hamming error of LSE", hamming_error([L, psi_hat_lei]))


# Two baselines that take the (n, n, M) adjacency tensor and K packed in one list.
psi_hat_mean_adj = mean_adj([Network, K])
err_mean_adj = hamming_error([L, psi_hat_mean_adj])
print("Hamming error of mean adjacency matrix is", err_mean_adj)

psi_hat_hosvd_tucker = hosvd_tucker([Network, K])
err_hosvd_tucker = hamming_error([L, psi_hat_hosvd_tucker])
print("Hamming error of HOSVD Tucker is", err_hosvd_tucker)

# The R speck function is given the layers as a list of (n, n) matrices; the
# layer count is taken from the tensor instead of being hard-coded.
network_list = [Network[:, :, layer] for layer in range(Network.shape[2])]

# speck is run repeatedly and its Hamming error averaged over the replications.
n_speck_reps = 100
hespck = np.zeros(n_speck_reps)
for rep in range(n_speck_reps):
    # Shift the R labels down by one (presumably 1-based from R -- confirm).
    psi_hat_speck = np.array(robjects.r.speck(network_list, n, K)) - 1
    # Ground truth first, estimate second: the same argument order as every
    # other hamming_error call in this script.
    hespck[rep] = hamming_error([L, psi_hat_speck])
print("Hamming error of Speck kernel is:", hespck.mean())

"""
Note:
HOSVD is not stable. Only the Speck kernel result above is averaged over 100 independent replications; HOSVD Tucker is run once.
"""


